package com.nulldev.util.internal.backport.httpclient_rw.impl.websocket;

import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;

import com.nulldev.util.data.Charsets.CharsetUtil;
import com.nulldev.util.internal.backport.httpclient_rw.impl.common.Logger;
import com.nulldev.util.internal.backport.httpclient_rw.impl.common.Utils;
import com.nulldev.util.internal.backport.httpclient_rw.impl.websocket.Frame.Opcode;

import static com.nulldev.util.internal.backport.concurrency9.Objects.requireNonNull;
import static java.lang.String.format;
import static com.nulldev.util.internal.backport.httpclient_rw.impl.common.Utils.dump;
import static com.nulldev.util.internal.backport.httpclient_rw.impl.websocket.StatusCodes.*;

/*
 * Consumes frame parts and notifies a message consumer when there is
 * sufficient data to produce a message, or a part thereof.
 *
 * Data consumed but not yet translated is accumulated until it is sufficient
 * to form a message.
 *
 * The callbacks read state recorded by earlier callbacks for the same frame
 * (e.g. opcode(...) reads the value passed to fin(...), payloadData(...) reads
 * the values passed to opcode(...) and payloadLen(...)), so the parts of a
 * frame must be reported in header order, followed by endFrame().
 */
/* Exposed for testing purposes */
class MessageDecoder implements Frame.Consumer {

	private static final Logger debug = Utils.getWebSocketLogger("[Input]"::toString, Utils.DEBUG_WS);

	/* Receives decoded text/binary parts, control payloads and close events. */
	private final MessageStreamConsumer output;
	/* UTF-8 decoder used for payloads of TEXT messages (including their continuations). */
	private final UTF8AccumulatingDecoder decoder = new UTF8AccumulatingDecoder();
	/* FIN bit of the frame currently being decoded. */
	private boolean fin;
	/*
	 * opcode: opcode of the frame currently being decoded (null after endFrame).
	 * originatingOpcode: TEXT or BINARY while a fragmented message is in
	 * progress, null otherwise.
	 */
	private Opcode opcode, originatingOpcode;
	/* Declared payload length of the current frame. */
	private long payloadLen;
	/* Payload bytes of the current frame that have not been received yet. */
	private long unconsumedPayloadLen;
	/* Accumulated payload of the current control frame; null between frames. */
	private ByteBuffer binaryData;

	MessageDecoder(MessageStreamConsumer output) {
		this.output = requireNonNull(output);
	}

	/* Exposed for testing purposes */
	MessageStreamConsumer getOutput() {
		return output;
	}

	/* Records the FIN bit; opcode(...) relies on it to validate the frame. */
	@Override
	public void fin(boolean value) {
		if (debug.on()) {
			debug.log("fin %s", value);
		}
		fin = value;
	}

	/* No extension is negotiated here, so a set RSV1 bit fails the connection. */
	@Override
	public void rsv1(boolean value) {
		if (debug.on()) {
			debug.log("rsv1 %s", value);
		}
		if (value) {
			throw new FailWebSocketException("Unexpected rsv1 bit");
		}
	}

	/* No extension is negotiated here, so a set RSV2 bit fails the connection. */
	@Override
	public void rsv2(boolean value) {
		if (debug.on()) {
			debug.log("rsv2 %s", value);
		}
		if (value) {
			throw new FailWebSocketException("Unexpected rsv2 bit");
		}
	}

	/* No extension is negotiated here, so a set RSV3 bit fails the connection. */
	@Override
	public void rsv3(boolean value) {
		if (debug.on()) {
			debug.log("rsv3 %s", value);
		}
		if (value) {
			throw new FailWebSocketException("Unexpected rsv3 bit");
		}
	}

	/*
	 * Validates the opcode against the fragmentation state:
	 * - control frames (PING/PONG/CLOSE) must not be fragmented;
	 * - TEXT/BINARY may only start a message when none is in progress;
	 * - CONTINUATION is only valid while a fragmented message is in progress.
	 * Any other opcode fails the connection.
	 */
	@Override
	public void opcode(Opcode v) {
		if (debug.on()) {
			debug.log("opcode %s", v);
		}
		if (v == Opcode.PING || v == Opcode.PONG || v == Opcode.CLOSE) {
			if (!fin) {
				throw new FailWebSocketException("Fragmented control frame  " + v);
			}
			opcode = v;
		} else if (v == Opcode.TEXT || v == Opcode.BINARY) {
			if (originatingOpcode != null) {
				throw new FailWebSocketException(format("Unexpected frame %s (fin=%s)", v, fin));
			}
			opcode = v;
			if (!fin) {
				// First fragment of a multi-frame message; remembered so that
				// CONTINUATION frames know whether they carry text or binary
				originatingOpcode = v;
			}
		} else if (v == Opcode.CONTINUATION) {
			if (originatingOpcode == null) {
				throw new FailWebSocketException(format("Unexpected frame %s (fin=%s)", v, fin));
			}
			opcode = v;
		} else {
			throw new FailWebSocketException("Unexpected opcode " + v);
		}
	}

	/* Frames received by the client must not be masked. */
	@Override
	public void mask(boolean value) {
		if (debug.on()) {
			debug.log("mask %s", value);
		}
		if (value) {
			throw new FailWebSocketException("Masked frame received");
		}
	}

	/*
	 * Records the payload length. For control frames, enforces the maximum
	 * control payload length and rejects a CLOSE payload of exactly one byte
	 * (too short to hold the two-byte status code).
	 */
	@Override
	public void payloadLen(long value) {
		if (debug.on()) {
			debug.log("payloadLen %s", value);
		}
		if (opcode.isControl()) {
			if (value > Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH) {
				throw new FailWebSocketException(format("%s's payload length %s", opcode, value));
			}
			assert Opcode.CLOSE.isControl();
			if (opcode == Opcode.CLOSE && value == 1) {
				throw new FailWebSocketException("Incomplete status code");
			}
		}
		payloadLen = value;
		unconsumedPayloadLen = value;
	}

	@Override
	public void maskingKey(int value) {
		// `MessageDecoder.mask(boolean)` is where a masked frame is detected and
		// reported on; `MessageDecoder.mask(boolean)` MUST be invoked before
		// this method.
		// So this method (`maskingKey`) is not supposed to be invoked while
		// reading a frame that has come from the server. If this method is
		// invoked, then it's an error in the implementation, thus InternalError.
		throw new InternalError();
	}

	/*
	 * Handles a chunk of the current frame's payload. Control payloads are
	 * accumulated until endFrame(); data payloads are delivered to the output
	 * immediately (text is decoded first).
	 */
	@Override
	public void payloadData(ByteBuffer data) {
		if (debug.on()) {
			debug.log("payload %s", data);
		}
		unconsumedPayloadLen -= data.remaining();
		boolean lastPayloadChunk = unconsumedPayloadLen == 0;
		if (opcode.isControl()) {
			if (binaryData == null) { // The first (possibly the only) chunk
				// When the payload is split, the first chunk shouldn't be 125
				// bytes, otherwise the next chunk will be of size 0, which is
				// not what Reader promises to deliver (eager reading)
				assert lastPayloadChunk || data.remaining() < Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH : dump(data.remaining());
				// payloadLen(...) has already bounded a control payload to
				// MAX_CONTROL_FRAME_PAYLOAD_LENGTH, so the cast is safe and the
				// buffer can be sized exactly
				binaryData = ByteBuffer.allocate((int) payloadLen);
			}
			binaryData.put(data);
		} else {
			boolean last = fin && lastPayloadChunk;
			boolean text = opcode == Opcode.TEXT || originatingOpcode == Opcode.TEXT;
			if (text) {
				deliverText(data, last);
			} else {
				output.onBinary(data.slice(), last);
				data.position(data.limit()); // Consume
			}
		}
	}

	/*
	 * Decodes a chunk of a TEXT message and passes the result to the output.
	 * A non-empty chunk that yields no characters is not delivered (presumably
	 * the decoder is holding back an incomplete UTF-8 sequence). Invalid UTF-8
	 * fails the connection with NOT_CONSISTENT.
	 */
	private void deliverText(ByteBuffer data, boolean last) {
		boolean binaryNonEmpty = data.hasRemaining();
		CharBuffer textData;
		try {
			textData = decoder.decode(data, last);
		} catch (CharacterCodingException e) {
			throw new FailWebSocketException("Invalid UTF-8 in frame " + opcode, StatusCodes.NOT_CONSISTENT).initCause(e);
		}
		if (!binaryNonEmpty || textData.hasRemaining()) {
			output.onText(textData, last);
		}
	}

	/*
	 * Completes the current frame. Delivers control payloads (CLOSE/PING/PONG),
	 * ends a fragmented data message when FIN is set, and resets per-frame state.
	 */
	@Override
	public void endFrame() {
		if (debug.on()) {
			debug.log("end frame");
		}
		ByteBuffer controlData = null;
		if (opcode.isControl()) {
			// Detach the payload from the decoder before delivering it, so that
			// a drained buffer is never left behind for the next control frame
			// (e.g. one following a CLOSE) to append to, even if delivery throws
			controlData = binaryData;
			binaryData = null;
			controlData.flip();
		}
		switch (opcode) {
			case CLOSE:
				deliverClose(controlData);
				break;
			case PING:
				output.onPing(controlData);
				break;
			case PONG:
				output.onPong(controlData);
				break;
			default:
				assert opcode == Opcode.TEXT || opcode == Opcode.BINARY || opcode == Opcode.CONTINUATION : dump(opcode);
				if (fin) {
					// It is always the last chunk:
					// either TEXT(FIN=TRUE)/BINARY(FIN=TRUE) or CONT(FIN=TRUE)
					originatingOpcode = null;
				}
				break;
		}
		payloadLen = 0;
		opcode = null;
	}

	/*
	 * Parses a CLOSE payload and notifies the output. An empty payload yields
	 * NO_STATUS_CODE and an empty reason. Otherwise, the first two bytes are the
	 * status code (it must be legal to receive from a server) and the rest is the
	 * reason, which must be valid UTF-8.
	 */
	private void deliverClose(ByteBuffer payload) {
		char statusCode = NO_STATUS_CODE;
		String reason = "";
		if (payloadLen != 0) {
			int len = payload.remaining();
			assert 2 <= len && len <= Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH : dump(len, payloadLen);
			statusCode = payload.getChar();
			if (!isLegalToReceiveFromServer(statusCode)) {
				// Cast so the message shows the numeric code rather than the
				// char it happens to map to
				throw new FailWebSocketException("Illegal status code: " + (int) statusCode);
			}
			try {
				reason = CharsetUtil.UTF_8.newDecoder().decode(payload).toString();
			} catch (CharacterCodingException e) {
				throw new FailWebSocketException("Illegal close reason").initCause(e);
			}
		}
		output.onClose(statusCode, reason);
	}
}
